In [1]:
from IPython.display import Latex, Math

DCGAN - Introduction

In this post, I will try to explain the DCGAN paper

Before we do that, you may want to look at the idea of GANs and how they work. Here's one such blog. In short, Generator network tries to generate images that are similar to actual images in order to fool the Discriminator. This is the objective of Generator. Input to the Generator is a random vector $Z$ of size, say 100, from which Generator tries to produce a fake image. The problem GANs generally face is called Mode Collapse, where Generator always produces only one fake image from any input random vector $Z$ provided.

DCGAN tried addressing these issues and suggested the following things :

Architecture guidelines for stable Deep Convolutional GANs

• Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).

• Use batchnorm in both the generator and the discriminator.

• Remove fully connected hidden layers for deeper architectures.

• Use ReLU activation in generator for all layers except for the output, which uses Tanh. • Use LeakyReLU activation in the discriminator for all layers.

Data - CelebA

In [2]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from pathlib import Path

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
Random Seed:  999
Out[2]:
<torch._C.Generator at 0x7f8a780bb390>

alt-text

alt-text

In [3]:
datapath = Path('/media/mano/Data/comicgen/ComicGen_CycleGAN/celeba-dataset/img_align_celeba/img_align_celeba/')
batch_size = 128
weight_sd = 0.02
disc_leakyrelu_rate = 0.2

'''As specified in the DCGAN paper, both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. '''
adam_lr = 0.0002
beta1 = 0.5
In [ ]:
'''
As the authors mentioned : zero-centered Normal distribution with standard deviation 0.02. we set our mean = 0 and sd = 0.02
'''
def params_init(m,standard_deviation=0.01):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

    

$y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$

architecture

stridedconv

Batchnorm exclusion in GAN

batchnormexclusion

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        '''
        Input is 100 dim vector.
        Dense layer (100 x 4*4*1024) to get 4 * 4 * 1024 vector
        Reshape it to (4,4,1024)
        ConvTranspose to (8,8,512)
        ConvTranspose to (16,16,256)
        ConvTranspose to (32,32,128)
        ConvTranspose to (64,64,3)
        
        Conv2D formula for output size is 
        
        output = ( n + 2p - f )/s + 1 , where
        
        n = height, width of image
        p = padding
        f = feature size / kernel size
        s = stride
        
        Conv2D Transpose formula for output size is 
        
        output = ( n - 1 )*s + f - 2p
        '''
        self.inpdense = nn.Linear(100,4*4*1024)
        self.conv1 = nn.ConvTranspose2d(1024,512,kernel_size=4,stride=2,padding=1) # 4,4 to 8,8
        self.conv2 = nn.ConvTranspose2d(512,256,kernel_size=4,stride=2,padding=1)  # 8,8 to 16,16
        self.conv3 = nn.ConvTranspose2d(256,128,kernel_size=4,stride=2,padding=1)  # 16,16 to 32,32
        self.conv4 = nn.ConvTranspose2d(128,3,kernel_size=4,stride=2,padding=1)    # 32,32 to 64,64
        self.conv1_bn = nn.BatchNorm2d(512)
        self.conv2_bn = nn.BatchNorm2d(256)
        self.conv3_bn = nn.BatchNorm2d(128)
        self.relu     = nn.ReLU()
        self.tanh     = nn.Tanh()
        
    def forward(self,inp):
        inp = self.inpdense(inp)
        bz  = inp.shape[0] 
        inp = inp.view(bz,1024,4,4)
        inp = self.conv1(inp)
        inp = self.conv1_bn(inp)
        inp = self.relu(inp)
        inp = self.conv2(inp)
        inp = self.conv2_bn(inp)
        inp = self.relu(inp)
        inp = self.conv3(inp)
        inp = self.conv3_bn(inp)
        inp = self.relu(inp)
        inp = self.conv4(inp)
        inp = self.tanh(inp)
        return inp
In [6]:
generator = Generator()
generator
Out[6]:
Generator(
  (inpdense): Linear(in_features=100, out_features=16384, bias=True)
  (conv1): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv4): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv1_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (tanh): Tanh()
)

discriminator

ReLU for Generator

and

LeakyReLU for Discriminator

ActivationFunctions

In [7]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        '''
        As mention, LeakyReLU is to be used.
        Input is (64,64,3)
        Conv to (32,32,128)
        Conv to (16,16,256)
        Conv to (8,8,512)
        Conv to (4,4,1024)
        Sigmoid (For binary classification)
        '''
        self.conv1    = nn.Conv2d(3,128,kernel_size=4,stride=2,padding=1) # (64,64,3) to (32,32,128) 
        self.conv2    = nn.Conv2d(128,256,kernel_size=4,stride=2,padding=1) # (32,32,128) to (16,16,256)
        self.conv3    = nn.Conv2d(256,512,kernel_size=4,stride=2,padding=1) # (16,16,256) to (8,8,512)
        self.conv4    = nn.Conv2d(512,1024,kernel_size=4,stride=2,padding=1) # (8,8,512) to (4,4,1024)
        self.conv5    = nn.Conv2d(1024,1,kernel_size=4,stride=1,padding=0) # (4,4,1024) to (1,1,1) 
        self.lrelu    = nn.LeakyReLU(0.2)
        self.sigmoid  = nn.Sigmoid()
        self.conv2_bn = nn.BatchNorm2d(256)
        self.conv3_bn = nn.BatchNorm2d(512)
        self.conv4_bn = nn.BatchNorm2d(1024)
        
    def forward(self,inp):
        inp = self.conv1(inp)
        inp = self.lrelu(inp)
        inp = self.conv2(inp)
        inp = self.conv2_bn(inp)
        inp = self.lrelu(inp)
        inp = self.conv3(inp)
        inp = self.conv3_bn(inp)
        inp = self.lrelu(inp)
        inp = self.conv4(inp)
        inp = self.conv4_bn(inp)
        inp = self.lrelu(inp)
        inp = self.conv5(inp)
        inp = self.sigmoid(inp)
        return inp
In [8]:
discriminator = Discriminator()
discriminator
Out[8]:
Discriminator(
  (conv1): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv4): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv5): Conv2d(1024, 1, kernel_size=(4, 4), stride=(1, 1))
  (lrelu): LeakyReLU(negative_slope=0.2)
  (sigmoid): Sigmoid()
  (conv2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

Loss Functions

$BCEloss = -(y*\log{p} + (1-y)*\log{(1-p)}) $

For a binary classification problem, y = {0,1}. If we say, P is the probability of an instance belonging to 1st class.Then,

For 0th class, $loss = - ( 0*\log{p} + (1-0)*\log{(1-p)} )$ ==> $-log(1-p)$

For 1st class, $loss = - ( 1*\log{p} + (1-1)*\log{(1-p)} )$ ==> $-log(p)$

We can use the same loss function for training both Discriminators and Generators. How ?

Probability of finding object to real be $P(D($x$))$.

Our loss function for Discriminator, given that $x$ is a real object and $G(z)$ is a fake object, should maximize $P(D($x$))$ and reduce P(D(G(z))), equal to $P(1-D(G(z)))$

Loss function for Generator is maximize $P(D(G(z)))$ and reduce $P(D(x))$

In [9]:
# Initialize BCELoss function
criterion = nn.BCELoss()


# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(discriminator.parameters(), lr=adam_lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(generator.parameters(), lr=adam_lr, betas=(beta1, 0.999))
In [10]:
# Size of z latent vector (i.e. size of generator input)
nz = 100

Data

In [11]:
import multiprocessing
In [12]:
dataroot = '/media/mano/Data/comicgen/ComicGen_CycleGAN/celeba-dataset/img_align_celeba/'
image_size = 64
# workers = 2
workers = multiprocessing.cpu_count()
ngpu = 1
# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)


# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
Out[12]:
<matplotlib.image.AxesImage at 0x7f8a1e590f98>
In [13]:
# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, device=device)

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0

Moving models to GPU

In [14]:
generator = generator.to(device)
discriminator = discriminator.to(device)

Training

In [15]:
num_epochs = 3
In [16]:
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        discriminator.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = discriminator(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, device=device)
        # Generate fake image batch with G
        fake = generator(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = discriminator(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        generator.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = discriminator(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = generator(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Starting Training Loop...
[0/3][0/1583]	Loss_D: 1.4355	Loss_G: 5.2351	D(x): 0.4986	D(G(z)): 0.5081 / 0.0058
[0/3][50/1583]	Loss_D: 0.1744	Loss_G: 8.5249	D(x): 0.9513	D(G(z)): 0.1081 / 0.0003
[0/3][100/1583]	Loss_D: 0.7275	Loss_G: 2.2964	D(x): 0.6218	D(G(z)): 0.1578 / 0.1238
[0/3][150/1583]	Loss_D: 0.6750	Loss_G: 2.1016	D(x): 0.6417	D(G(z)): 0.1534 / 0.1449
[0/3][200/1583]	Loss_D: 1.4418	Loss_G: 4.4716	D(x): 0.9349	D(G(z)): 0.6988 / 0.0238
[0/3][250/1583]	Loss_D: 0.7467	Loss_G: 2.1327	D(x): 0.6949	D(G(z)): 0.2693 / 0.1477
[0/3][300/1583]	Loss_D: 1.3055	Loss_G: 1.3813	D(x): 0.3374	D(G(z)): 0.0763 / 0.2986
[0/3][350/1583]	Loss_D: 0.8965	Loss_G: 2.1935	D(x): 0.7929	D(G(z)): 0.4460 / 0.1364
[0/3][400/1583]	Loss_D: 0.9217	Loss_G: 1.8440	D(x): 0.6074	D(G(z)): 0.2788 / 0.1879
[0/3][450/1583]	Loss_D: 0.8203	Loss_G: 1.8469	D(x): 0.6876	D(G(z)): 0.3095 / 0.1835
[0/3][500/1583]	Loss_D: 0.7411	Loss_G: 3.3521	D(x): 0.8333	D(G(z)): 0.3960 / 0.0427
[0/3][550/1583]	Loss_D: 0.9328	Loss_G: 2.2314	D(x): 0.6835	D(G(z)): 0.3750 / 0.1278
[0/3][600/1583]	Loss_D: 0.6484	Loss_G: 2.4441	D(x): 0.7133	D(G(z)): 0.1751 / 0.1209
[0/3][650/1583]	Loss_D: 0.6654	Loss_G: 3.2170	D(x): 0.8452	D(G(z)): 0.3647 / 0.0495
[0/3][700/1583]	Loss_D: 0.5841	Loss_G: 3.5390	D(x): 0.8663	D(G(z)): 0.3348 / 0.0371
[0/3][750/1583]	Loss_D: 0.5205	Loss_G: 2.5762	D(x): 0.8265	D(G(z)): 0.2554 / 0.0911
[0/3][800/1583]	Loss_D: 0.5796	Loss_G: 2.4675	D(x): 0.7379	D(G(z)): 0.2089 / 0.1010
[0/3][850/1583]	Loss_D: 0.8191	Loss_G: 1.9125	D(x): 0.6999	D(G(z)): 0.3310 / 0.1720
[0/3][900/1583]	Loss_D: 0.6805	Loss_G: 3.9284	D(x): 0.8648	D(G(z)): 0.3843 / 0.0241
[0/3][950/1583]	Loss_D: 1.3082	Loss_G: 4.8423	D(x): 0.8615	D(G(z)): 0.6339 / 0.0122
[0/3][1000/1583]	Loss_D: 0.4357	Loss_G: 3.0763	D(x): 0.8752	D(G(z)): 0.2397 / 0.0577
[0/3][1050/1583]	Loss_D: 0.9288	Loss_G: 2.4834	D(x): 0.8233	D(G(z)): 0.4812 / 0.1036
[0/3][1100/1583]	Loss_D: 0.5955	Loss_G: 2.1405	D(x): 0.7637	D(G(z)): 0.2482 / 0.1435
[0/3][1150/1583]	Loss_D: 0.3771	Loss_G: 3.0417	D(x): 0.8460	D(G(z)): 0.1745 / 0.0589
[0/3][1200/1583]	Loss_D: 0.6567	Loss_G: 2.5074	D(x): 0.8360	D(G(z)): 0.3473 / 0.1002
[0/3][1250/1583]	Loss_D: 0.3145	Loss_G: 2.5114	D(x): 0.7906	D(G(z)): 0.0576 / 0.1004
[0/3][1300/1583]	Loss_D: 0.6113	Loss_G: 2.6613	D(x): 0.8461	D(G(z)): 0.3339 / 0.0836
[0/3][1350/1583]	Loss_D: 0.0284	Loss_G: 4.0064	D(x): 0.9939	D(G(z)): 0.0219 / 0.0238
[0/3][1400/1583]	Loss_D: 0.4486	Loss_G: 2.9393	D(x): 0.7752	D(G(z)): 0.1436 / 0.0665
[0/3][1450/1583]	Loss_D: 0.8425	Loss_G: 3.6469	D(x): 0.8830	D(G(z)): 0.4827 / 0.0333
[0/3][1500/1583]	Loss_D: 0.4596	Loss_G: 3.0065	D(x): 0.7090	D(G(z)): 0.0588 / 0.0762
[0/3][1550/1583]	Loss_D: 0.8553	Loss_G: 1.8704	D(x): 0.5052	D(G(z)): 0.0595 / 0.2026
[1/3][0/1583]	Loss_D: 0.3524	Loss_G: 2.6749	D(x): 0.8662	D(G(z)): 0.1653 / 0.0897
[1/3][50/1583]	Loss_D: 0.4313	Loss_G: 4.7481	D(x): 0.9282	D(G(z)): 0.2802 / 0.0110
[1/3][100/1583]	Loss_D: 0.3145	Loss_G: 3.6425	D(x): 0.9233	D(G(z)): 0.1946 / 0.0338
[1/3][150/1583]	Loss_D: 0.4167	Loss_G: 2.8152	D(x): 0.9161	D(G(z)): 0.2576 / 0.0758
[1/3][200/1583]	Loss_D: 0.0875	Loss_G: 3.4893	D(x): 0.9688	D(G(z)): 0.0529 / 0.0411
[1/3][250/1583]	Loss_D: 0.4749	Loss_G: 3.3456	D(x): 0.8226	D(G(z)): 0.1991 / 0.0474
[1/3][300/1583]	Loss_D: 0.2608	Loss_G: 2.9768	D(x): 0.9001	D(G(z)): 0.1313 / 0.0678
[1/3][350/1583]	Loss_D: 0.4739	Loss_G: 5.7619	D(x): 0.6567	D(G(z)): 0.0015 / 0.0048
[1/3][400/1583]	Loss_D: 3.6565	Loss_G: 5.9052	D(x): 0.0403	D(G(z)): 0.0001 / 0.0039
[1/3][450/1583]	Loss_D: 0.4639	Loss_G: 3.8150	D(x): 0.9652	D(G(z)): 0.3299 / 0.0277
[1/3][500/1583]	Loss_D: 0.0096	Loss_G: 5.7952	D(x): 0.9941	D(G(z)): 0.0036 / 0.0039
[1/3][550/1583]	Loss_D: 0.4371	Loss_G: 7.0130	D(x): 0.9726	D(G(z)): 0.3099 / 0.0012
[1/3][600/1583]	Loss_D: 0.0754	Loss_G: 5.1592	D(x): 0.9355	D(G(z)): 0.0067 / 0.0076
[1/3][650/1583]	Loss_D: 0.1134	Loss_G: 3.2303	D(x): 0.9650	D(G(z)): 0.0723 / 0.0495
[1/3][700/1583]	Loss_D: 0.2313	Loss_G: 3.1892	D(x): 0.8746	D(G(z)): 0.0834 / 0.0549
[1/3][750/1583]	Loss_D: 0.0374	Loss_G: 7.1415	D(x): 0.9643	D(G(z)): 0.0005 / 0.0010
[1/3][800/1583]	Loss_D: 0.7910	Loss_G: 3.6535	D(x): 0.5103	D(G(z)): 0.0183 / 0.0431
[1/3][850/1583]	Loss_D: 0.6752	Loss_G: 2.5005	D(x): 0.7062	D(G(z)): 0.2352 / 0.1041
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-16-919369d95c92> in <module>
     59         # Calculate gradients for G
     60         errG.backward()
---> 61         D_G_z2 = output.mean().item()
     62         # Update G
     63         optimizerG.step()

KeyboardInterrupt: 

Visualize the generated images. Alright! They are improving. Training atleast three epochs would be better.

In [31]:
for i in range(len(img_list)):
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(np.transpose(img_list[i].cpu(),(1,2,0)))
In [23]:
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(img_list[4].cpu(),(1,2,0)))
Out[23]:
<matplotlib.image.AxesImage at 0x7f8a23cc3e80>

More on batchnorm

explainBN

In [ ]: